# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
"""Tests for NCCL/RCCL"""

import tempfile

import numpy as np
import pytest

import tvm
import tvm.testing
from tvm import dlight as dl
from tvm import get_global_func
from tvm import relax as rx
from tvm.runtime import disco as di
from tvm.runtime.vm import VirtualMachine
from tvm.script import relax as R

# Disco session implementations that the tests in this file are parametrized over.
_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
# Single-element list holding the value returned by the runtime's
# "runtime.disco.compiled_ccl" global function.  `create_device_target` maps
# "nccl" to CUDA and any other value to ROCm.
_ccl = [get_global_func("runtime.disco.compiled_ccl")()]


def create_device_target(ccl):
    """Return the ``(device, target)`` pair used to compile and run for ``ccl``.

    ``"nccl"`` selects CUDA device 0; any other value selects ROCm device 0.
    The target is derived from the selected device.
    """
    device = tvm.cuda(0) if ccl == "nccl" else tvm.rocm(0)
    return (device, tvm.target.Target.from_device(device))


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_init(session_kind, ccl):
    """Create a two-worker session and initialize the CCL backend on devices 0 and 1."""
    device_ids = [0, 1]
    session = session_kind(num_workers=len(device_ids))
    session.init_ccl(ccl, *device_ids)


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_allreduce(session_kind, ccl):
    """Check every allreduce op across two workers against its NumPy equivalent."""
    device_ids = [0, 1]
    session = session_kind(num_workers=len(device_ids))
    session.init_ccl(ccl, *device_ids)

    shape = (3, 4)
    lhs = np.arange(12, dtype="float32").reshape(shape)
    rhs = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(shape)
    d_input = session.empty(shape, "float32")
    # Worker 0 holds `lhs`, worker 1 holds `rhs`.
    for worker_id, host_data in enumerate((lhs, rhs)):
        d_input.debug_copy_from(worker_id, host_data)

    # Reduction name understood by the session -> element-wise NumPy reference.
    reference_ops = {
        "sum": np.add,
        "prod": np.multiply,
        "min": np.minimum,
        "max": np.maximum,
        "avg": lambda a, b: (a + b) * 0.5,
    }
    for op_name, reference in reference_ops.items():
        d_output = session.empty(shape, "float32")
        session.allreduce(d_input, d_output, op=op_name)
        actual = d_output.debug_get_from_remote(0).numpy()
        np.testing.assert_equal(actual, reference(lhs, rhs))


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_group_allreduce(session_kind, ccl):
    """Check in-group allreduce with four workers split into two groups.

    The test expects workers 0/1 and workers 2/3 to reduce as separate pairs:
    the first pair works on (3, 4) data, the second on (5, 6) data.
    """
    device_ids = [0, 1, 2, 3]
    session = session_kind(num_workers=len(device_ids), num_groups=2)
    session.init_ccl(ccl, *device_ids)

    pair0_lhs = np.arange(12, dtype="float32").reshape(3, 4)
    pair0_rhs = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
    pair1_lhs = np.arange(30, dtype="float32").reshape(5, 6)
    pair1_rhs = np.arange(start=1, stop=-29, step=-1, dtype="float32").reshape(5, 6)

    d_pair0_in = session.empty((3, 4), "float32")
    d_pair1_in = session.empty((5, 6), "float32")
    uploads = [
        (d_pair0_in, 0, pair0_lhs),
        (d_pair0_in, 1, pair0_rhs),
        (d_pair1_in, 2, pair1_lhs),
        (d_pair1_in, 3, pair1_rhs),
    ]
    for d_target, worker_id, host_data in uploads:
        d_target.debug_copy_from(worker_id, host_data)

    # Reduction name understood by the session -> element-wise NumPy reference.
    reference_ops = {
        "sum": np.add,
        "prod": np.multiply,
        "min": np.minimum,
        "max": np.maximum,
        "avg": lambda a, b: (a + b) * 0.5,
    }
    for op_name, reference in reference_ops.items():
        d_pair0_out = session.empty((3, 4), "float32")
        d_pair1_out = session.empty((5, 6), "float32")
        session.allreduce(d_pair0_in, d_pair0_out, op=op_name, in_group=True)
        session.allreduce(d_pair1_in, d_pair1_out, op=op_name, in_group=True)
        # Read one worker from each pair: worker 0 and worker 2.
        actual_pair0 = d_pair0_out.debug_get_from_remote(0).numpy()
        actual_pair1 = d_pair1_out.debug_get_from_remote(2).numpy()
        np.testing.assert_equal(actual_pair0, reference(pair0_lhs, pair0_rhs))
        np.testing.assert_equal(actual_pair1, reference(pair1_lhs, pair1_rhs))


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_allgather(session_kind, ccl):
    """Check that allgather delivers the concatenated halves to both workers."""
    device_ids = [0, 1]
    session = session_kind(num_workers=len(device_ids))
    session.init_ccl(ccl, *device_ids)

    flat_data = np.arange(36, dtype="float32")
    d_shard = session.empty((3, 3, 2), "float32")
    d_gathered = session.empty((3, 4, 3), "float32")
    # Worker 0 is given the first 18 elements, worker 1 the remaining 18.
    for worker_id, shard in enumerate((flat_data[:18], flat_data[18:])):
        d_shard.debug_copy_from(worker_id, shard)
    session.allgather(d_shard, d_gathered)

    expected = flat_data.reshape(3, 4, 3)
    for worker_id in device_ids:
        np.testing.assert_equal(
            d_gathered.debug_get_from_remote(worker_id).numpy(),
            expected,
        )


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_group_allgather(session_kind, ccl):
    """Check in-group allgather with four workers split into two groups.

    Workers 0/1 are given halves of a 36-element array and workers 2/3 halves
    of a 48-element array; each worker is expected to see only its own pair's
    concatenated data.
    """
    device_ids = [0, 1, 2, 3]
    session = session_kind(num_workers=len(device_ids), num_groups=2)
    session.init_ccl(ccl, *device_ids)

    flat_pair0 = np.arange(36, dtype="float32")
    flat_pair1 = np.arange(48, dtype="float32")
    d_src_pair0 = session.empty((3, 3, 2), "float32")
    d_dst_pair0 = session.empty((3, 4, 3), "float32")
    d_src_pair1 = session.empty((2, 4, 3), "float32")
    d_dst_pair1 = session.empty((2, 6, 4), "float32")

    uploads = [
        (d_src_pair0, 0, flat_pair0[:18]),
        (d_src_pair0, 1, flat_pair0[18:]),
        (d_src_pair1, 2, flat_pair1[:24]),
        (d_src_pair1, 3, flat_pair1[24:]),
    ]
    for d_target, worker_id, shard in uploads:
        d_target.debug_copy_from(worker_id, shard)

    session.allgather(d_src_pair0, d_dst_pair0, in_group=True)
    session.allgather(d_src_pair1, d_dst_pair1, in_group=True)

    expected_pair0 = flat_pair0.reshape(3, 4, 3)
    expected_pair1 = flat_pair1.reshape(2, 6, 4)
    checks = [
        (d_dst_pair0, 0, expected_pair0),
        (d_dst_pair0, 1, expected_pair0),
        (d_dst_pair1, 2, expected_pair1),
        (d_dst_pair1, 3, expected_pair1),
    ]
    for d_result, worker_id, expected in checks:
        np.testing.assert_equal(
            d_result.debug_get_from_remote(worker_id).numpy(),
            expected,
        )


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
@pytest.mark.parametrize("use_explicit_output", [True, False])
def test_broadcast(session_kind, ccl, use_explicit_output):
    """Check that broadcast data arrives unchanged on worker 1.

    With ``use_explicit_output`` the test allocates source and destination
    arrays itself and calls ``broadcast_from_worker0``; otherwise it hands the
    NumPy array straight to ``broadcast``.
    """
    device_ids = [0, 1]
    session = session_kind(num_workers=len(device_ids))
    session.init_ccl(ccl, *device_ids)

    payload = np.arange(12, dtype="float32").reshape(3, 4)

    if not use_explicit_output:
        d_received = session.broadcast(payload)
    else:
        d_source = session.empty((3, 4), "float32", worker0_only=True)
        d_source.debug_copy_from(0, payload)
        d_received = session.empty((3, 4), "float32")
        session.broadcast_from_worker0(d_source, d_received)

    np.testing.assert_equal(d_received.debug_get_from_remote(1).numpy(), payload)


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_group_broadcast(session_kind, ccl):
    """Check broadcast with four workers split into two groups.

    Workers 0 and 2 are given different source data; worker 1 is expected to
    receive worker 0's data and worker 3 is expected to receive worker 2's.
    """
    device_ids = [0, 1, 2, 3]
    session = session_kind(num_workers=len(device_ids), num_groups=2)
    session.init_ccl(ccl, *device_ids)

    positive = np.arange(12, dtype="float32").reshape(3, 4)
    negated = np.multiply(positive, -1)

    d_source = session.empty((3, 4), "float32", worker0_only=True, in_group=True)
    for worker_id, host_data in ((0, positive), (2, negated)):
        d_source.debug_copy_from(worker_id, host_data)
    d_received = session.empty((3, 4), "float32")
    session.broadcast_from_worker0(d_source, d_received)

    for worker_id, expected in ((1, positive), (3, negated)):
        actual = d_received.debug_get_from_remote(worker_id).numpy()
        np.testing.assert_equal(actual, expected)


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
@pytest.mark.parametrize("use_explicit_output", [True, False])
def test_scatter(session_kind, ccl, use_explicit_output, capfd):
    """Check that scatter hands slice ``i`` of the outermost axis to worker ``i``.

    With ``use_explicit_output`` the test allocates the arrays itself and calls
    ``scatter_from_worker0``; otherwise it passes the NumPy array to
    ``scatter``.  Nothing may be written to stderr in either case.
    """
    device_ids = [0, 1]
    session = session_kind(num_workers=len(device_ids))
    session.init_ccl(ccl, *device_ids)

    payload = np.arange(36, dtype="float32").reshape(2, 6, 3)

    if not use_explicit_output:
        d_slice = session.scatter(payload)
    else:
        d_full = session.empty((2, 6, 3), "float32", worker0_only=True)
        d_slice = session.empty((6, 3), "float32")
        d_full.debug_copy_from(0, payload)
        session.scatter_from_worker0(d_full, d_slice)

    for worker_id in device_ids:
        np.testing.assert_equal(
            d_slice.debug_get_from_remote(worker_id).numpy(),
            payload[worker_id, :, :],
        )

    stderr_text = capfd.readouterr().err
    assert (
        not stderr_text
    ), "No warning messages should be generated from disco.Session.scatter_from_worker0"


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_group_scatter(session_kind, ccl, capfd):
    """Check scatter with four workers split into two groups.

    Workers 0 and 2 are given different source arrays.  Workers 0/1 are
    expected to receive the two slices of worker 0's array and workers 2/3 the
    two slices of worker 2's array.  Nothing may be written to stderr.
    """
    device_ids = [0, 1, 2, 3]
    session = session_kind(num_workers=len(device_ids), num_groups=2)
    session.init_ccl(ccl, *device_ids)

    positive = np.arange(36, dtype="float32").reshape(2, 6, 3)
    negated = np.multiply(positive, -1)

    d_full = session.empty((2, 6, 3), "float32", worker0_only=True, in_group=True)
    for worker_id, host_data in ((0, positive), (2, negated)):
        d_full.debug_copy_from(worker_id, host_data)
    d_slice = session.empty((6, 3), "float32")
    session.scatter_from_worker0(d_full, d_slice)

    # (worker id, source array, index of the slice that worker should hold)
    expectations = [
        (0, positive, 0),
        (1, positive, 1),
        (2, negated, 0),
        (3, negated, 1),
    ]
    for worker_id, source, slice_index in expectations:
        np.testing.assert_equal(
            d_slice.debug_get_from_remote(worker_id).numpy(),
            source[slice_index, :, :],
        )

    stderr_text = capfd.readouterr().err
    assert (
        not stderr_text
    ), "No warning messages should be generated from disco.Session.scatter_from_worker0"


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_scatter_with_implicit_reshape(session_kind, ccl, capfd):
    """Scatter may split by element count rather than by the outermost axis

    The source here has shape (3, 4, 3), so its outermost dimension (3)
    cannot be divided evenly between the two workers.  Its 36 elements
    can, however: each worker is expected to receive 18 consecutive
    elements of the flattened source, viewed as a (3, 3, 2) array.

    This flexibility is only available through an explicit call to
    `sess.scatter_from_worker0`.  The `sess.scatter` method may have
    to allocate on the disco workers, so it requires the scatter to
    happen across the outermost dimension and does not permit this
    implicit reshape.

    """
    device_ids = [0, 1]
    session = session_kind(num_workers=len(device_ids))
    session.init_ccl(ccl, *device_ids)

    payload = np.arange(36, dtype="float32").reshape(3, 4, 3)

    d_full = session.empty((3, 4, 3), "float32", worker0_only=True)
    d_part = session.empty((3, 3, 2), "float32")
    d_full.debug_copy_from(0, payload)
    session.scatter_from_worker0(d_full, d_part)

    flattened = payload.flat
    expected_parts = (flattened[:18], flattened[18:])
    for worker_id, expected_flat in enumerate(expected_parts):
        np.testing.assert_equal(
            d_part.debug_get_from_remote(worker_id).numpy(),
            expected_flat.reshape(3, 3, 2),
        )

    stderr_text = capfd.readouterr().err
    assert (
        not stderr_text
    ), "No warning messages should be generated from disco.Session.scatter_from_worker0"


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_gather(session_kind, ccl, capfd):
    """Check that gather assembles both workers' halves on worker 0.

    Nothing may be written to stderr while doing so.
    """
    device_ids = [0, 1]
    session = session_kind(num_workers=len(device_ids))
    session.init_ccl(ccl, *device_ids)

    flat_data = np.arange(36, dtype="float32")
    d_shard = session.empty((3, 3, 2), "float32")
    d_gathered = session.empty((3, 4, 3), "float32", worker0_only=True)
    # Worker 0 is given the first 18 elements, worker 1 the remaining 18.
    for worker_id, shard in enumerate((flat_data[:18], flat_data[18:])):
        d_shard.debug_copy_from(worker_id, shard)
    session.gather_to_worker0(d_shard, d_gathered)

    expected = flat_data.reshape(3, 4, 3)
    np.testing.assert_equal(d_gathered.debug_get_from_remote(0).numpy(), expected)

    stderr_text = capfd.readouterr().err
    assert (
        not stderr_text
    ), "No warning messages should be generated from disco.Session.gather_to_worker0"


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_group_gather(session_kind, ccl, capfd):
    """Check gather with four workers split into two groups.

    Workers 0/1 are given halves of one array and workers 2/3 halves of its
    negation.  Worker 0 and worker 2 are each expected to end up with their
    own pair's full array.  Nothing may be written to stderr.
    """
    device_ids = [0, 1, 2, 3]
    session = session_kind(num_workers=len(device_ids), num_groups=2)
    session.init_ccl(ccl, *device_ids)

    flat_positive = np.arange(36, dtype="float32")
    flat_negated = np.multiply(flat_positive, -1)
    d_shard = session.empty((3, 3, 2), "float32")
    d_gathered = session.empty((3, 4, 3), "float32", worker0_only=True, in_group=True)

    # Shard for worker i sits at position i.
    shards = (
        flat_positive[:18],
        flat_positive[18:],
        flat_negated[:18],
        flat_negated[18:],
    )
    for worker_id, shard in enumerate(shards):
        d_shard.debug_copy_from(worker_id, shard)
    session.gather_to_worker0(d_shard, d_gathered)

    for worker_id, flat_expected in ((0, flat_positive), (2, flat_negated)):
        np.testing.assert_equal(
            d_gathered.debug_get_from_remote(worker_id).numpy(),
            flat_expected.reshape(3, 4, 3),
        )

    stderr_text = capfd.readouterr().err
    assert (
        not stderr_text
    ), "No warning messages should be generated from disco.Session.gather_to_worker0"


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_send_to_next_group_receive_from_prev_group(session_kind, ccl):
    """Check the backend's group-to-group send/receive test helper.

    Data is placed on workers 0 and 1; after the helper runs, workers 2 and 3
    are expected to hold the data of workers 0 and 1 respectively.
    """
    device_ids = [0, 1, 2, 3]
    session = session_kind(num_workers=len(device_ids), num_groups=2)
    session.init_ccl(ccl, *device_ids)

    first = np.arange(12, dtype="float32").reshape(3, 4)
    second = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
    d_array = session.empty((3, 4), "float32")
    for worker_id, host_data in enumerate((first, second)):
        d_array.debug_copy_from(worker_id, host_data)

    helper_name = "runtime.disco." + ccl + ".test_send_to_next_group_recv_from_prev_group"
    transfer = session.get_global_func(helper_name)
    transfer(d_array)

    on_worker2 = d_array.debug_get_from_remote(2).numpy()
    on_worker3 = d_array.debug_get_from_remote(3).numpy()
    np.testing.assert_equal(on_worker2, first)
    np.testing.assert_equal(on_worker3, second)


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_worker2_send_to_worker0(session_kind, ccl):
    """Check the backend's test helper that sends worker 2's data to worker 0."""
    device_ids = [0, 1, 2, 3]
    session = session_kind(num_workers=len(device_ids), num_groups=2)
    session.init_ccl(ccl, *device_ids)

    payload = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
    d_array = session.empty((3, 4), "float32")
    d_array.debug_copy_from(2, payload)

    helper_name = "runtime.disco." + ccl + ".test_worker2_sends_to_worker0"
    transfer = session.get_global_func(helper_name)
    transfer(d_array)

    np.testing.assert_equal(d_array.debug_get_from_remote(0).numpy(), payload)


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_mlp(session_kind, ccl):  # pylint: disable=too-many-locals
    """Compare a two-way sharded MLP run through disco with a single-device MLP.

    `MLP` is built and run directly on one device to produce the reference
    output.  `ShardedMLP` takes half of each weight matrix per worker,
    broadcasts the input from worker 0 and sums the partial products with an
    allreduce; its result is read back from worker 0.
    """
    device_ids = [0, 1]
    session = session_kind(num_workers=len(device_ids))
    session.init_ccl(ccl, *device_ids)

    # pylint: disable=invalid-name
    @tvm.script.ir_module
    class MLP:  # pylint: disable=too-few-public-methods
        @R.function
        def main(
            x: R.Tensor((128, 128), "float32"),
            W1: R.Tensor((128, 128), "float32"),
            W2: R.Tensor((128, 128), "float32"),
        ) -> R.Tensor((128, 128), "float32"):
            R.func_attr({"global_symbol": "main"})
            with R.dataflow():
                lv0: R.Tensor((128, 128), "float32") = R.matmul(x, W1)
                lv1: R.Tensor((128, 128), "float32") = R.nn.gelu(lv0)
                lv2: R.Tensor((128, 128), "float32") = R.matmul(lv1, W2)
                R.output(lv2)
            return lv2

    @tvm.script.ir_module
    class ShardedMLP:  # pylint: disable=too-few-public-methods
        @R.function
        def main(
            x: R.Tensor((128, 128), "float32"),
            W1: R.Tensor((128, 64), "float32"),  # shard along axis 1
            W2: R.Tensor((64, 128), "float32"),  # shard along axis 0
        ) -> R.Tensor((128, 128), "float32"):
            R.func_attr({"global_symbol": "main"})
            with R.dataflow():
                broadcast_x: R.Tensor((128, 128), "float32") = R.ccl.broadcast_from_worker0(x)
                lv0: R.Tensor((128, 64), "float32") = R.matmul(broadcast_x, W1)
                lv1: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
                lv2: R.Tensor((128, 128), "float32") = R.matmul(lv1, W2)
                lv3: R.Tensor((128, 128), "float32") = R.ccl.allreduce(lv2, "sum")
                R.output(lv3)
            return lv3

    # pylint: enable=invalid-name
    dev, target = create_device_target(ccl)

    def build_executable(ir_mod, build_target):
        """Run the "zero" pipeline and default GPU schedules, then compile."""
        with build_target:
            lowered = rx.get_pipeline("zero")(ir_mod)  # pylint: disable=no-value-for-parameter
            apply_schedules = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            )
            scheduled = apply_schedules(lowered)
            return tvm.compile(scheduled, target=build_target)

    # Host-side input and full (unsharded) weights, drawn in the order x, W1, W2.
    x_np, w1_np, w2_np = (np.random.randn(128, 128).astype("float32") for _ in range(3))

    # Reference result from the unsharded module on a single device.
    reference_vm = VirtualMachine(build_executable(MLP, target), device=dev)
    expected = reference_vm["main"](
        tvm.runtime.tensor(x_np, device=dev),
        tvm.runtime.tensor(w1_np, device=dev),
        tvm.runtime.tensor(w2_np, device=dev),
    ).numpy()

    with tempfile.TemporaryDirectory() as tmpdir:
        lib_path = tmpdir + "/test.so"
        build_executable(ShardedMLP, target).export_library(lib_path)

        sharded_mod = session.load_vm_module(lib_path)

        d_x = session.empty((128, 128), "float32")
        d_w1 = session.empty((128, 64), "float32")
        d_w2 = session.empty((64, 128), "float32")

        d_x.debug_copy_from(0, x_np)
        # W1 is split by columns and W2 by rows; worker i gets the i-th half.
        w1_shards = (w1_np[:, :64], w1_np[:, 64:])
        w2_shards = (w2_np[:64, :], w2_np[64:, :])
        for worker_id, shard in enumerate(w1_shards):
            d_w1.debug_copy_from(worker_id, shard)
        for worker_id, shard in enumerate(w2_shards):
            d_w2.debug_copy_from(worker_id, shard)

        d_y = sharded_mod["main"](d_x, d_w1, d_w2)
        y_local = tvm.runtime.empty((128, 128), "float32", device=dev)
        session.copy_from_worker_0(y_local, d_y)
        session.sync_worker_0()
        actual = y_local.numpy()

    np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_attention(session_kind, ccl):  # pylint: disable=too-many-locals,too-many-statements
    """Compare a two-way sharded attention layer run through disco with the unsharded one.

    `Attention` is built and run directly on one device to produce the
    reference output.  `ShardedAttention` takes half of each projection weight
    per worker (4 of the 8 heads), broadcasts the input from worker 0 and sums
    the partial output projections with an allreduce; its result is read back
    from worker 0.

    Both modules return a (1, 10, 128) tensor, and their return annotations
    say so.
    """
    devices = [0, 1]
    sess = session_kind(num_workers=len(devices))
    sess.init_ccl(ccl, *devices)

    # pylint: disable=invalid-name
    @tvm.script.ir_module
    class Attention:  # pylint: disable=too-few-public-methods
        @R.function
        def main(  # pylint: disable=too-many-locals
            x: R.Tensor((1, 10, 128), "float32"),
            Wq: R.Tensor((128, 512), "float32"),
            Wk: R.Tensor((128, 512), "float32"),
            Wv: R.Tensor((128, 512), "float32"),
            Wo: R.Tensor((512, 128), "float32"),
        ) -> R.Tensor((1, 10, 128), "float32"):
            R.func_attr({"global_symbol": "main"})
            with R.dataflow():
                # q
                lv0: R.Tensor((1, 10, 512), "float32") = R.matmul(x, Wq)
                lv1: R.Tensor((1, 10, 8, 64), "float32") = R.reshape(lv0, [1, 10, 8, 64])
                lv2: R.Tensor((1, 8, 10, 64), "float32") = R.permute_dims(lv1, [0, 2, 1, 3])
                # k
                lv3: R.Tensor((1, 10, 512), "float32") = R.matmul(x, Wk)
                lv4: R.Tensor((1, 10, 8, 64), "float32") = R.reshape(lv3, [1, 10, 8, 64])
                lv5: R.Tensor((1, 8, 10, 64), "float32") = R.permute_dims(lv4, [0, 2, 1, 3])
                # v
                lv6: R.Tensor((1, 10, 512), "float32") = R.matmul(x, Wv)
                lv7: R.Tensor((1, 10, 8, 64), "float32") = R.reshape(lv6, [1, 10, 8, 64])
                lv8: R.Tensor((1, 8, 10, 64), "float32") = R.permute_dims(lv7, [0, 2, 1, 3])
                # softmax(q @ k / sqrt(dk))
                lv9: R.Tensor((1, 8, 64, 10), "float32") = R.permute_dims(lv5, [0, 1, 3, 2])
                lv10: R.Tensor((1, 8, 10, 10), "float32") = R.matmul(lv2, lv9)
                lv11: R.Tensor((1, 8, 10, 10), "float32") = R.multiply(
                    lv10, R.const(1 / 8, "float32")
                )
                lv12: R.Tensor((1, 8, 10, 10), "float32") = R.nn.softmax(lv11, axis=-1)
                # attn_weight @ v
                lv13: R.Tensor((1, 8, 10, 64), "float32") = R.matmul(lv12, lv8)
                lv14: R.Tensor((1, 10, 8, 64), "float32") = R.permute_dims(lv13, [0, 2, 1, 3])
                lv15: R.Tensor((1, 10, 512), "float32") = R.reshape(lv14, [1, 10, 512])
                # attn_output @ o
                lv16: R.Tensor((1, 10, 128), "float32") = R.matmul(lv15, Wo)
                R.output(lv16)
            return lv16

    @tvm.script.ir_module
    class ShardedAttention:  # pylint: disable=too-few-public-methods
        @R.function
        def main(  # pylint: disable=too-many-locals
            x: R.Tensor((1, 10, 128), "float32"),
            Wq: R.Tensor((128, 256), "float32"),  # shard along axis 1
            Wk: R.Tensor((128, 256), "float32"),  # shard along axis 1
            Wv: R.Tensor((128, 256), "float32"),  # shard along axis 1
            Wo: R.Tensor((256, 128), "float32"),  # shard along axis 0
        ) -> R.Tensor((1, 10, 128), "float32"):
            R.func_attr({"global_symbol": "main"})
            with R.dataflow():
                broadcast_x: R.Tensor((1, 10, 128), "float32") = R.ccl.broadcast_from_worker0(x)
                # q
                lv0: R.Tensor((1, 10, 256), "float32") = R.matmul(broadcast_x, Wq)
                lv1: R.Tensor((1, 10, 4, 64), "float32") = R.reshape(lv0, [1, 10, 4, 64])
                lv2: R.Tensor((1, 4, 10, 64), "float32") = R.permute_dims(lv1, [0, 2, 1, 3])
                # k
                lv3: R.Tensor((1, 10, 256), "float32") = R.matmul(broadcast_x, Wk)
                lv4: R.Tensor((1, 10, 4, 64), "float32") = R.reshape(lv3, [1, 10, 4, 64])
                lv5: R.Tensor((1, 4, 10, 64), "float32") = R.permute_dims(lv4, [0, 2, 1, 3])
                # v
                lv6: R.Tensor((1, 10, 256), "float32") = R.matmul(broadcast_x, Wv)
                lv7: R.Tensor((1, 10, 4, 64), "float32") = R.reshape(lv6, [1, 10, 4, 64])
                lv8: R.Tensor((1, 4, 10, 64), "float32") = R.permute_dims(lv7, [0, 2, 1, 3])
                # softmax(q @ k / sqrt(dk))
                lv9: R.Tensor((1, 4, 64, 10), "float32") = R.permute_dims(lv5, [0, 1, 3, 2])
                lv10: R.Tensor((1, 4, 10, 10), "float32") = R.matmul(lv2, lv9)
                lv11: R.Tensor((1, 4, 10, 10), "float32") = R.multiply(
                    lv10, R.const(1 / 8, "float32")
                )
                lv12: R.Tensor((1, 4, 10, 10), "float32") = R.nn.softmax(lv11, axis=-1)
                # attn_weight @ v
                lv13: R.Tensor((1, 4, 10, 64), "float32") = R.matmul(lv12, lv8)
                lv14: R.Tensor((1, 10, 4, 64), "float32") = R.permute_dims(lv13, [0, 2, 1, 3])
                lv15: R.Tensor((1, 10, 256), "float32") = R.reshape(lv14, [1, 10, 256])
                # attn_output @ o
                lv16: R.Tensor((1, 10, 128), "float32") = R.matmul(lv15, Wo)
                # Each worker holds a partial output projection; sum them.
                lv17: R.Tensor((1, 10, 128), "float32") = R.ccl.allreduce(lv16, "sum")
                R.output(lv17)
            return lv17

    # pylint: enable=invalid-name
    dev, target = create_device_target(ccl)

    def relax_build(mod, target):
        """Run the "zero" pipeline and default GPU schedules, then compile."""
        with target:
            mod = rx.get_pipeline("zero")(mod)  # pylint: disable=no-value-for-parameter
            mod = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            )(mod)
            return tvm.compile(mod, target=target)

    # pylint: disable=invalid-name
    # Host-side input and full (unsharded) weights.
    X = np.random.randn(1, 10, 128).astype("float32")
    Wq = np.random.randn(128, 512).astype("float32")
    Wk = np.random.randn(128, 512).astype("float32")
    Wv = np.random.randn(128, 512).astype("float32")
    Wo = np.random.randn(512, 128).astype("float32")
    # Reference result from the unsharded module on a single device.
    Y_expected = VirtualMachine(relax_build(Attention, target), device=dev)["main"](
        tvm.runtime.tensor(X, device=dev),
        tvm.runtime.tensor(Wq, device=dev),
        tvm.runtime.tensor(Wk, device=dev),
        tvm.runtime.tensor(Wv, device=dev),
        tvm.runtime.tensor(Wo, device=dev),
    ).numpy()

    with tempfile.TemporaryDirectory() as tmpdir:
        path = tmpdir + "/test.so"
        relax_build(ShardedAttention, target).export_library(path)

        mod = sess.load_vm_module(path)

        d_X = sess.empty((1, 10, 128), "float32")
        d_Wq = sess.empty((128, 256), "float32")
        d_Wk = sess.empty((128, 256), "float32")
        d_Wv = sess.empty((128, 256), "float32")
        d_Wo = sess.empty((256, 128), "float32")

        # Only worker 0 is given the input; the module broadcasts it.
        d_X.debug_copy_from(0, X)
        # Wq/Wk/Wv are split by columns and Wo by rows; worker i gets the i-th half.
        d_Wq.debug_copy_from(0, Wq[:, :256])
        d_Wq.debug_copy_from(1, Wq[:, 256:])
        d_Wk.debug_copy_from(0, Wk[:, :256])
        d_Wk.debug_copy_from(1, Wk[:, 256:])
        d_Wv.debug_copy_from(0, Wv[:, :256])
        d_Wv.debug_copy_from(1, Wv[:, 256:])
        d_Wo.debug_copy_from(0, Wo[:256, :])
        d_Wo.debug_copy_from(1, Wo[256:, :])
        d_Y = mod["main"](d_X, d_Wq, d_Wk, d_Wv, d_Wo)
        Y_result = tvm.runtime.empty((1, 10, 128), "float32", device=dev)
        sess.copy_from_worker_0(Y_result, d_Y)
        sess.sync_worker_0()
        Y_result = Y_result.numpy()
    # pylint: enable=invalid-name
    np.testing.assert_allclose(Y_result, Y_expected, rtol=1e-3, atol=1e-3)


# When executed directly as a script, hand control to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()
